from __future__ import annotations

from pathlib import Path

import pandas as pd


def fmt(x, digits: int = 4) -> str:
    """Format a single table cell value as text.

    - ``None`` and pandas missing values (NaN, ``pd.NA``, ...) become ``""``.
    - Integer scalars (Python ``int`` or NumPy integer, but not ``bool``) are
      printed without a decimal part, so counts such as N read "1234" rather
      than "1234.0000".
    - Anything else accepted by ``float()`` is printed with ``digits`` decimals.
    - Everything else falls back to ``str(x)``.

    Args:
        x: The cell value.
        digits: Number of decimal places for non-integer numeric values.

    Returns:
        The formatted cell text (not LaTeX-escaped).
    """
    if x is None:
        return ""
    try:
        if pd.isna(x):
            return ""
    except (TypeError, ValueError):
        # pd.isna on an array-like returns an array whose truth value is
        # ambiguous (ValueError); treat such values as "not missing".
        pass
    # pandas' is_integer excludes bool, so True/False keep the float rendering.
    if pd.api.types.is_integer(x):
        return str(int(x))
    try:
        v = float(x)
    except (TypeError, ValueError, OverflowError):
        return str(x)
    return f"{v:.{digits}f}"


def _latex_escape(text: str) -> str:
    """Escape the LaTeX specials ``&``, ``%``, ``_`` and ``#`` in ``text``.

    Deliberately minimal: ``$``, braces, backslash, ``~`` and ``^`` are passed
    through unchanged.
    """
    return (
        text.replace("&", r"\&")
        .replace("%", r"\%")
        .replace("_", r"\_")
        .replace("#", r"\#")
    )


def write_tabular(df: pd.DataFrame, path: Path, align: str, digits: int = 4) -> None:
    """Write ``df`` to ``path`` as a LaTeX ``tabular`` with booktabs rules.

    The output is ``\\begin{tabular}{align}``, ``\\toprule``, one header row
    built from the column names, ``\\midrule``, one row per DataFrame row,
    ``\\bottomrule`` and ``\\end{tabular}``. Parent directories of ``path``
    are created if needed.

    String cells are LaTeX-escaped as-is. All other cells go through
    ``fmt(val, digits=digits)`` and are then escaped as well.

    Args:
        df: Table to render; the index is not written.
        path: Destination ``.tex`` file, written as UTF-8.
        align: Column specification placed in ``\\begin{tabular}{...}``.
        digits: Decimal places passed to ``fmt`` for numeric cells.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    lines = [r"\begin{tabular}{" + align + "}", r"\toprule"]
    lines.append(" & ".join(_latex_escape(str(c)) for c in df.columns) + r" \\")
    lines.append(r"\midrule")
    # itertuples keeps each column's own dtype. iterrows would upcast a row to
    # a common dtype (e.g. ints to floats) and is much slower. It is also
    # positional, so duplicate column names cannot yield a Series per cell.
    for row in df.itertuples(index=False, name=None):
        cells = [
            _latex_escape(val) if isinstance(val, str) else _latex_escape(fmt(val, digits=digits))
            for val in row
        ]
        lines.append(" & ".join(cells) + r" \\")
    lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def _split_pipe_row(line: str) -> list[str]:
    """Split one markdown pipe-table row into stripped cell texts.

    Only one leading and one trailing ``|`` delimiter are removed. An empty
    edge cell (e.g. ``| b | 3 ||``) is therefore kept, instead of being
    swallowed as ``str.strip("|")`` would do.
    """
    if line.startswith("|"):
        line = line[1:]
    if line.endswith("|"):
        line = line[:-1]
    return [cell.strip() for cell in line.split("|")]


def _is_separator_row(cells: list[str]) -> bool:
    """Return True for a header/body delimiter row such as ``|:---|---:|``."""
    return all(cell and set(cell) <= set(":-") and "-" in cell for cell in cells)


def read_pipe_md_table(path: Path) -> pd.DataFrame:
    """Parse a markdown pipe table from ``path`` into a DataFrame.

    Only lines starting with ``|`` (after stripping whitespace) are
    considered. The first such line is the header.

    The following rows are skipped:
    - delimiter rows (``|---|:--:|``);
    - rows repeating the header, which happens when several tables are
      concatenated in one file;
    - rows whose cell count differs from the header's.

    Every column except ``label`` is converted with
    ``pd.to_numeric(errors="coerce")``, so non-numeric text becomes NaN.

    Returns:
        The parsed table. An empty DataFrame is returned when fewer than
        three pipe lines (header, delimiter, one row) are present.
    """
    lines = [ln.strip() for ln in path.read_text(encoding="utf-8").splitlines()]
    data = [ln for ln in lines if ln.startswith("|")]
    if len(data) < 3:
        return pd.DataFrame()
    header = _split_pipe_row(data[0])
    rows = []
    for ln in data[1:]:
        cells = _split_pipe_row(ln)
        if len(cells) != len(header):
            continue
        if cells == header or _is_separator_row(cells):
            continue
        rows.append(cells)
    df = pd.DataFrame(rows, columns=header)
    for c in df.columns:
        if c == "label":
            continue
        df[c] = pd.to_numeric(df[c], errors="coerce")
    return df


# Source columns of the shift-share / Bartik result files, in display order.
_RESULT_COLS = [
    "label",
    "nobs",
    "coef_lowEdu",
    "se_lowEdu",
    "coef_highEdu",
    "coef_deltaHighMinusLow",
    "se_deltaHighMinusLow",
    "fsF_netusoft",
    "fsF_netusoft_x_edu",
    "AR_pvalue",
]

# LaTeX header text for each source column, shared by every exported table.
_DISPLAY_NAMES = {
    "label": "Outcome",
    "model": "Model",
    "nobs": "N",
    "coef_lowEdu": "b(LowEdu)",
    "se_lowEdu": "se(LowEdu)",
    "coef_highEdu": "b(HighEdu)",
    "coef_deltaHighMinusLow": "Delta(High-Low)",
    "se_deltaHighMinusLow": "se(Delta)",
    "fsF_netusoft": "F(net)",
    "fsF_netusoft_x_edu": "F(net×edu)",
    "AR_pvalue": "AR p",
}


def _read_optional(path: Path, reader) -> pd.DataFrame:
    """Return ``reader(path)``, or an empty DataFrame if ``path`` does not exist.

    A "Skip (missing)" line is printed for a missing file.
    """
    if not path.exists():
        print(f"Skip (missing): {path}")
        return pd.DataFrame()
    return reader(path)


def _prepare(
    df: pd.DataFrame, cols: list[str], *, drop_empty: bool, strict: bool = True
) -> pd.DataFrame:
    """Select ``cols`` from ``df`` (in that order) and rename them for display.

    Args:
        df: Raw results table.
        cols: Source column names to keep.
        drop_empty: Keep only rows with ``nobs > 0``. Rows with NaN ``nobs``
            compare False and are dropped too.
        strict: When True, a missing column raises ``KeyError``. When False,
            missing columns are silently skipped.
    """
    if drop_empty:
        df = df[df["nobs"] > 0]
    if not strict:
        cols = [c for c in cols if c in df.columns]
    return df[cols].rename(columns=_DISPLAY_NAMES)


def _write(df: pd.DataFrame, path: Path, *, n_text: int = 1) -> None:
    """Write ``df`` via ``write_tabular`` with 4 decimals.

    The first ``n_text`` columns are left-aligned and the rest right-aligned.
    The alignment spec is derived from the actual column count, so it always
    matches the table.
    """
    align = "l" * n_text + "r" * max(len(df.columns) - n_text, 0)
    write_tabular(df, path, align=align, digits=4)


def main() -> None:
    """Convert IV result files under ``<root>/outputs`` into LaTeX tables.

    ``root`` is the parent of this script's directory. Tables are written to
    ``<root>/paper/tables``.

    Required inputs (a missing file raises):
    - the shift-share 2SLS main table;
    - the OLS-vs-IV comparison CSV.

    Optional inputs (a missing file is reported and skipped):
    - the LIML table;
    - the placebo/mechanism markdown table;
    - the Bartik table.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "paper" / "tables"
    outputs = root / "outputs"

    # Shift-share 2SLS main (required).
    main_2sls = pd.read_csv(outputs / "iv_shiftshare_absorb_main_table.csv")
    _write(
        _prepare(main_2sls, _RESULT_COLS, drop_empty=True),
        out_dir / "iv_shiftshare_main_2sls.tex",
    )

    # Shift-share LIML main (optional).
    main_liml = _read_optional(outputs / "iv_shiftshare_absorb_liml_main_table.csv", pd.read_csv)
    if not main_liml.empty:
        _write(
            _prepare(main_liml, _RESULT_COLS, drop_empty=True),
            out_dir / "iv_shiftshare_main_liml.tex",
        )

    # Placebo + mechanism, parsed from markdown (optional; may not exist in
    # minimal runs). Columns absent from the file are skipped; no nobs filter.
    pm = _read_optional(
        outputs / "iv_shiftshare_absorb_placebo_mechanism.md", read_pipe_md_table
    )
    if not pm.empty:
        _write(
            _prepare(pm, _RESULT_COLS, drop_empty=False, strict=False),
            out_dir / "iv_shiftshare_placebo_mech.tex",
        )

    # OLS vs IV (required); two left-aligned text columns: Outcome and Model.
    comp = pd.read_csv(outputs / "iv_shiftshare_ols_vs_iv.csv")
    comp_cols = ["label", "model", *_RESULT_COLS[1:]]
    _write(
        _prepare(comp, comp_cols, drop_empty=False),
        out_dir / "iv_shiftshare_ols_vs_iv.tex",
        n_text=2,
    )

    # Bartik appendix (optional); this table does not include the AR p-value.
    bartik = _read_optional(outputs / "iv_bartik_absorb_main_table.csv", pd.read_csv)
    if not bartik.empty:
        bartik_cols = [c for c in _RESULT_COLS if c != "AR_pvalue"]
        _write(
            _prepare(bartik, bartik_cols, drop_empty=True),
            out_dir / "iv_bartik_main.tex",
        )

    print(f"Wrote LaTeX tables to: {out_dir}")


if __name__ == "__main__":
    # Export the tables only when run as a script, not when imported.
    main()
